import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import random

class ActiveLearning_Framework:
    """
    Bookkeeping for an active-learning (AL) loop.

    The training dataset is split into a training subset and a validation
    subset; the training subset is split further into a labeled seed pool and
    an unlabeled pool. ``Update_AL_Datapool`` then moves queried indices from
    the unlabeled pool into the labeled pool.

    Every pool is a ``data.Subset`` over ``self.trainset`` (or, for the ``*_AL``
    pools, ``self.trainset_AL``) built on the shared ``self.labeled_idx`` /
    ``self.unlabeled_idx`` lists. Those lists are only ever mutated in place,
    so pool objects handed out before an update still reflect it.
    """

    def __init__(self, train_dataset: data.Dataset, train_dataset_AL: data.Dataset, test_dataset: data.Dataset,
                 num_seed_data, num_class, validation_ratio, device):
        """
        Active Learning framework is responsible for all AL procedures, e.g., update labeled and unlabeled data pools;
        get the statistics of queried data.

        Both random splits use seed ``self.Randseed`` (0), so they are reproducible.
        Note that this reseeds Python's global ``random`` module as a side effect.

        :param: train_dataset: training dataset
        :param: train_dataset_AL: AL dataset, same as train_dataset, no augmentation
        :param: test_dataset: testing dataset
        :param: num_seed_data: (int) number of initial seed (labeled) data
        :param: num_class: (int) number of image classes in the dataset
        :param: validation_ratio: (float) ratio of data held out for validation
        :param: device: (torch.device) GPU/CPU; ``self.weights`` is allocated on it
        :raises ValueError: if num_seed_data is not smaller than the number of training data
        """
        self.testset = test_dataset
        self.validation_ratio = validation_ratio
        self.num_class = num_class
        self.Randseed = 0
        self.device = device

        # Randomly split the training and validation datasets. train_dataset and
        # train_dataset_AL are indexed with the same permutation so they stay aligned.
        num_total_data = int(len(train_dataset))
        random.seed(self.Randseed)
        train_valid_list = list(range(num_total_data))
        random.shuffle(train_valid_list)
        num_training_data = int(num_total_data * (1 - self.validation_ratio))
        self.trainset = data.Subset(train_dataset, train_valid_list[0:num_training_data])
        self.trainset_AL = data.Subset(train_dataset_AL, train_valid_list[0:num_training_data])
        self.validationset = data.Subset(train_dataset, train_valid_list[num_training_data:])

        # Randomly split the initial seed data and the remaining unlabeled data.
        # These indices are positions within self.trainset / self.trainset_AL.
        if num_seed_data >= num_training_data:
            raise ValueError('initial label data number exceeds total training data')
        random.seed(self.Randseed)
        seed_data_list = list(range(num_training_data))
        random.shuffle(seed_data_list)
        self.labeled_idx = seed_data_list[0:num_seed_data]
        self.unlabeled_idx = seed_data_list[num_seed_data:]
        self._rebuild_pools()
        # Indices moved into the labeled pool by the most recent update.
        self.idx_just_queried = []
        # One training weight per labeled sample (currently always 1).
        self.weights = torch.ones(len(self.labeled_pool), device=self.device)
        self.optim_results = None

    def _rebuild_pools(self):
        """(Re)create the labeled/unlabeled pool Subsets from the shared index lists."""
        self.labeled_pool = data.Subset(self.trainset, self.labeled_idx)
        self.unlabeled_pool = data.Subset(self.trainset, self.unlabeled_idx)
        self.labeled_pool_AL = data.Subset(self.trainset_AL, self.labeled_idx)
        self.unlabeled_pool_AL = data.Subset(self.trainset_AL, self.unlabeled_idx)

    def get_train_dataset(self):
        """Return the labeled pool (over the training dataset)."""
        return self.labeled_pool

    def get_train_dataset_AL(self):
        """Return the labeled pool (over the AL, non-augmented dataset)."""
        return self.labeled_pool_AL

    def get_unlabeled_dataset(self):
        """Return the unlabeled pool (over the training dataset)."""
        return self.unlabeled_pool

    def get_unlabeled_dataset_AL(self):
        """Return the unlabeled pool (over the AL, non-augmented dataset)."""
        return self.unlabeled_pool_AL

    def get_validation_dataset(self):
        """Return the validation subset."""
        return self.validationset

    def get_test_dataset(self):
        """Return the test dataset."""
        return self.testset

    def Update_AL_Datapool(self, queried_idx, weights=None):
        """
        Move queried samples from the unlabeled pool to the labeled pool.

        :param: queried_idx: (list) indices (positions in the training subset) of the
            queried samples; each must currently be in the unlabeled pool and appear once
        :param: weights: (torch tensor, optional) per-sample weights for the queried
            samples. Currently ignored: every newly labeled sample gets weight 1.
        :raises ValueError: if an index is duplicated or not in the unlabeled pool;
            in that case no state is modified
        """
        # Normalize to plain ints so numpy / torch integer scalars hash and compare like ints.
        queried = [int(i) for i in queried_idx]
        queried_set = set(queried)
        if len(queried_set) != len(queried):
            raise ValueError('queried_idx contains duplicate indices')
        unlabeled_set = set(self.unlabeled_idx)
        missing = [i for i in queried if i not in unlabeled_set]
        if missing:
            raise ValueError(f'{len(missing)} queried indices are not in the unlabeled pool, e.g. {missing[:5]}')

        # Mutate the index lists in place (never rebind them) so every Subset built
        # on them, including ones handed out earlier, sees the update.
        self.labeled_idx.extend(queried)
        self.unlabeled_idx[:] = [i for i in self.unlabeled_idx if i not in queried_set]
        self.idx_just_queried = queried

        # Keep exactly one weight per labeled sample. Passed-in weights are not
        # used yet; all samples are weighted 1.
        self.weights = torch.ones(len(self.labeled_idx), device=self.device)

        self._rebuild_pools()

    def get_labels_just_moved(self):
        """
        Return the labels of the most recently queried samples, as a list
        (empty if the last query was empty).
        """
        if len(self.idx_just_queried) == 0:
            # DataLoader rejects batch_size=0.
            return []
        dataset = data.Subset(self.trainset, self.idx_just_queried)
        loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        # A single batch holds every queried sample; element [1] is the collated labels.
        return next(iter(loader))[1].tolist()

    def get_labels_just_moved_stats(self):
        """
        Return the class counts of the most recently queried samples as a
        ``{label: count}`` dict. Single-element tensor labels are converted to
        Python scalars so equal labels share one key.
        """
        label_stats = {}
        dataset = data.Subset(self.trainset, self.idx_just_queried)
        for i in range(len(dataset)):
            _, label = dataset[i]
            if torch.is_tensor(label) and label.numel() == 1:
                # Tensors hash by identity; use the scalar value as the key.
                label = label.item()
            label_stats[label] = label_stats.get(label, 0) + 1
        return label_stats


class idx_Dataset(torch.utils.data.Dataset):
    """
    Dataset whose item at position ``idx`` is ``idx`` itself.

    Iterating a DataLoader over it yields (collated) batches of indices. No
    bounds check is done in ``__getitem__``; only ``__len__`` uses ``length``.
    """

    def __init__(self, length):
        # Number of items reported by __len__.
        self.length = length

    def __len__(self):
        size = self.length
        return size

    def __getitem__(self, idx):
        # Identity mapping: the requested position is the item.
        item = idx
        return item
